"""
This module contains some functions for creating feature vectors. By default, Caton pre-computes feature vectors rather than calculating them for the current dataset. The spikes used to create the feature vectors may have different sampling rates and intervals than the spikes in the current dataset, so we need to interpolate to get the appropriate feature vectors.

Two files contain the precomputed features: features.txt and timeseries.txt. They have the following format:
   timeseries.txt: N lines. Each line contains one sample time, measured in seconds, relative to the peak.
   features.txt: one line per feature. Each line contains N amplitudes: the waveform of that feature sampled at the times in timeseries.txt. The first line is the best feature (approx. amplitude), the second line the next best, etc.
"""
from __future__ import division, with_statement
import numpy as np
import os
from utils_misc import switch_ext,find_file_with_ext
from output import read_spk, read_clu, get_pars_from_xml2
from CEM_extensions import class_means, class_covs

# Locations of the precomputed feature files, stored in a "data" directory
# next to this module.
feature_path = os.path.join(os.path.split(__file__)[0],"data/features.txt")
ts_path = os.path.join(os.path.split(__file__)[0],"data/timeseries.txt")

def compute_pcs(X_ns):
    """Compute principal components of X_ns

    Parameters
    ----------------
    X_ns : array. each row is a data point

    Returns
    -----------------
    PC_ss : array (float32). Each row is a principal component (a unit
        eigenvector of the covariance of X_ns), sorted in DECREASING order
        of eigenvalue, so the first row explains the most variance.
    """
    # use float64 for a numerically stable eigendecomposition; cast back below
    Cov_ss = np.cov(X_ns.T.astype(np.float64))
    # the covariance matrix is symmetric, so the symmetric solver applies
    Vals,Vecs = np.linalg.eigh(Cov_ss)
    # eigh gives eigenvalues in ascending order; reverse so the
    # largest-variance component comes first
    return Vecs.astype(np.float32).T[np.argsort(Vals)[::-1]]

def compute_feats_lda(clu_means,clu_covs):
    """
    Compute features using an LDA-inspired method.
    Reference: Bishop, PRML 1ed, section 4.1.7
    Let S_W and S_B be within-cluster and between-cluster covariance matrices. Components are taken to be eigenvectors of S_W^-1*S_B

    Parameters
    ----------------
    clu_means : list of per-cluster mean waveforms (1-d arrays)
    clu_covs : list of per-cluster covariance matrices

    Returns
    -----------------
    array whose rows are feature vectors, sorted in decreasing order of
    eigenvalue (best feature first)
    """
    
    # normalize each covariance by its trace so one noisy cluster doesn't have undue influence. Is this a good idea?
    S_W = sum(clu_cov/np.trace(clu_cov) for clu_cov in clu_covs)
    # regularize S_W so it is invertible
    S_W += .1*np.eye(S_W.shape[0])
    # between-cluster covariance: covariance of the cluster means
    S_B = np.cov(np.array(clu_means).T)
    # NOTE(review): the line below is the true LDA projection described in the
    # docstring, but it is commented out; S_W is computed above and then unused.
    # As written the function eigendecomposes S_B alone (PCA of the cluster
    # means). Unclear whether this is deliberate — confirm before "fixing".
    #Vals,Vecs = np.linalg.eig(np.dot(np.linalg.inv(S_W),S_B))
    Vals,Vecs = np.linalg.eig(S_B)

    # rows = eigenvectors, largest eigenvalue first
    return Vecs.T[np.argsort(Vals)[::-1]]
    

def make_lda_features_from_spks(SpkFileNames):
    "all files should have been generated by the same parameters"

    clu_means, clu_covs = [],[]
    pars0 = n_ch,sample_rate,s_before,s_after = get_pars_from_xml2(switch_ext(SpkFileNames[0],"xml"))
    print "parameters: %s"%str(pars0[1:])
    
    for fname in SpkFileNames:
        pars = get_pars_from_xml2(switch_ext(fname,"xml"))
        if pars[1:] != pars0[1:]:
            print "%s has wrong parameters. skipping"%fname
            continue
        n_ch = pars[0]
        X_nsc = read_spk(fname,n_ch,s_after+s_before)
        Clu_n = read_clu(switch_ext(fname,"clu.1"))
        for i_ch in xrange(n_ch):
            Mean_ms, Cov_mss, Count_m = get_clu_params(Clu_n,X_nsc[:,:,i_ch])
            for mean,cov,count in zip(Mean_ms, Cov_mss, Count_m):
                if count > 100: 
                    clu_means.append(mean)
                    clu_covs.append(cov)
    
    Feat_ss = compute_feats_lda(clu_means,clu_covs)
    save_feature_info(Feat_ss,s_before,s_after,sample_rate)
    

def save_feature_info(Feats_ss,s_before,s_after,sample_rate):
    """
    Save feature waveforms and their sample times to the precomputed-feature
    files (features.txt and timeseries.txt).

    Parameters
    ----------------
    Feats_ss : array. each row is a feature waveform.
    s_before, s_after : int. samples before and after the peak.
    sample_rate : float. samples per second.
    """
    # sample times in seconds, relative to the peak
    sample_times = np.arange(-s_before,s_after)/sample_rate
    np.savetxt(ts_path,sample_times)
    np.savetxt(feature_path,Feats_ss)

def make_pca_features_from_spk(SpkFileName):
    """
    Make pca features from a .spk file (first channel only) and save them to
    features.txt via save_feature_info.
    """
    xml_file = find_file_with_ext(os.path.dirname(SpkFileName),"xml",True)
    n_ch,sample_rate,s_before,s_after = get_pars_from_xml2(xml_file)
    samples_per_spike = s_before+s_after
    # raw int16 waveforms, reshaped to (spike, sample, channel); keep channel 0
    Wave_ns = np.fromfile(SpkFileName,dtype=np.int16).reshape(-1,samples_per_spike,n_ch)[:,:,0]
    save_feature_info(compute_pcs(Wave_ns),s_before,s_after,sample_rate)

    
def get_clu_params(Class_n,X_nf):
    """
    Per-cluster statistics of the rows of X_nf.

    Parameters
    ----------------
    Class_n : int array of cluster labels, one per row of X_nf
    X_nf : array. each row is a data point

    Returns
    -----------------
    Mean_mf, Cov_mff, Count_m : per-cluster means, covariances, and counts,
    indexed by cluster label (bincount order)
    """
    Count_m = np.bincount(Class_n).astype(np.int32)
    # number of clusters = highest label + 1
    n_clusters = Count_m.size
    Data_nf = X_nf.astype(np.float32)
    Mean_mf = class_means(Data_nf,Class_n,n_clusters)
    Cov_mff = class_covs(Data_nf,Mean_mf,Class_n,n_clusters)
    return Mean_mf, Cov_mff, Count_m
    
def get_features(s_before,s_after,sample_rate,F):
    """
    Create feature vectors with appropriate sampling interval from features saved in features.txt. 
    
    Parameters
    -----------------
    s_before, s_after : int. samples before and after the peak in the target sampling.
    sample_rate : float. target sampling rate (samples per second).
    F : number of features
    
    Returns
    -----------
    array (float32), shape (F, s_before+s_after). Each row is one saved
    feature, linearly interpolated at the requested sample times.
    """
    # atleast_2d: loadtxt returns a 1-d array when features.txt holds a
    # single feature, which would otherwise make the loop iterate scalars
    Feats_fs = np.atleast_2d(np.loadtxt(feature_path))
    TS_s = np.loadtxt(ts_path)
    # sample times (seconds, relative to peak) at the requested rate
    new_ts = np.arange(-s_before,s_after,dtype=np.float32)/sample_rate
    return np.array([np.interp(new_ts,TS_s,Feat_s) for Feat_s in Feats_fs]).astype(np.float32)[:F]

def plot_features(F=None,savefig=False,output_dir=None):
    """
    Plot features from features.txt
    
    Parameters
    -------------------
    F : int or None. number of waveforms to plot (None = plot all)
    savefig : boolean. If true, then save figure as features.png in output_dir
        (or the current directory). Otherwise plot immediately.
    output_dir : str or None. directory for the saved figure.
    """
    import matplotlib.pyplot as plt
    TS_s = np.loadtxt(ts_path)
    Feats_fs = np.loadtxt(feature_path)[:F]
    for Feat_s in Feats_fs:
        plt.plot(TS_s,Feat_s)
    # label by feature index; use the number actually plotted so the default
    # F=None (plot everything) doesn't crash on range(None)
    plt.legend([str(i) for i in range(len(Feats_fs))])
    if savefig:
        outpath = os.path.join(output_dir or ".","features.png")
        plt.savefig(outpath)
    else:
        plt.show()    
        
if __name__ == "__main__":
    # Rebuild the LDA feature set from every saved batch spike file.
    from glob import glob
    spk_files = glob("/home/joschu/Data/d11221_saved/*batch/*.spk.1")
    make_lda_features_from_spks(spk_files)